/**
 * Counts "balanced" zero positions in an integer array.
 *
 * <p>A position holding {@code 0} is scored by comparing the sum of all
 * elements strictly to its left with the sum of all elements strictly to its
 * right.
 */
public class CountValidSelections {
    /**
     * Scores every index {@code i} with {@code nums[i] == 0} and returns the total.
     *
     * <ul>
     *   <li>left sum equals right sum: contributes 2;</li>
     *   <li>left and right sums differ by exactly 1: contributes 1;</li>
     *   <li>otherwise: contributes 0.</li>
     * </ul>
     *
     * <p>Sums are accumulated in {@code long} so that large element values
     * cannot wrap around and make an unbalanced position look balanced.
     *
     * @param nums the input array; an empty array yields 0
     * @return the total score over all zero-valued positions
     * @throws NullPointerException if {@code nums} is {@code null}
     */
    public static int countValidSelections(int[] nums) {
        // Grand total; the right-hand sum at a zero is derived from it.
        long total = 0;
        for (int value : nums) {
            total += value;
        }

        int ans = 0;
        long left = 0; // sum of the elements strictly before the current index
        for (int value : nums) {
            if (value == 0) {
                // The current element is 0, so everything not on the left is on the right.
                long right = total - left;
                long diff = Math.abs(left - right);
                if (diff == 0) {
                    ans += 2;
                } else if (diff == 1) {
                    ans += 1;
                }
            }
            left += value;
        }
        return ans;
    }

    /** Small demo: prints the score for a sample array that contains no zeros. */
    public static void main(String[] args) {
        int[] nums = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
        System.out.println(countValidSelections(nums));
    }
}
